import torch as t
import numpy as np
from util import inverseSrgb, srgb
from tqdm import trange

class PsfGuess(t.nn.Module):
  """Learnable single-channel 2-D convolution kernel (a PSF guess, per the name).

  The kernel has shape (2*h+1, 2*w+1) and starts as a unit impulse at its
  centre, i.e. the identity filter.

  Args:
    w: horizontal half-width of the kernel.
    h: vertical half-height of the kernel.
    lr: AdamW learning rate used by ``update`` when no explicit ``lr`` is given.
  """

  def __init__(self, w=32, h=64, lr=0.001):
    super().__init__()
    self.a = t.nn.Conv2d(
      in_channels=1, out_channels=1,
      kernel_size=(h*2+1, w*2+1),
      padding=(h, w),
      bias=False,
    )
    # Not used by any method of this class; kept only because it is a public
    # attribute that outside code may reference.
    self.optim = t.optim.SGD(self.a.parameters(), lr=lr)
    self.lr = lr
    with t.no_grad():
      self.a.weight.zero_()
      self.a.weight[:, :, h, w] = 1
    # Stored as (w, h): dims[0] pads the width (last) axis, dims[1] the height axis.
    self.dims = (w, h)

  def _pad(self, x, value):
    """Pad the last two axes of x by h rows / w columns on each side with `value`."""
    w, h = self.dims
    return t.nn.functional.pad(x, (w, w, h, h), value=value)

  def forward(self, x):
    """Convolve x with the kernel using the Conv2d layer's zero padding.

    A leading axis is added before the convolution and removed afterwards, so
    an (H, W) input (assumed; TODO confirm callers) yields an (H, W) output.
    """
    return self.a(x.unsqueeze(0)).squeeze(0)

  def update(self, x, yplusbd, padval=0, iters=400, lr=None):
    """Refit the kernel so that conv(pad(x, padval), kernel) approximates yplusbd.

    Minimises sum((conv(pad(x, padval), kernel) - yplusbd) ** 2) over the kernel
    weights with `iters` steps of a freshly created AdamW optimiser. Optimiser
    state is therefore not carried over between calls.

    Args:
      x: input image, treated as a constant (PsfFinder passes an (H, W) tensor).
      yplusbd: target with the same spatial shape as x.
      padval: constant used to pad x beyond its borders.
      iters: number of optimisation steps.
      lr: AdamW learning rate; defaults to the ``lr`` given to the constructor.
    """
    if lr is None:
      lr = self.lr
    opt = t.optim.AdamW(self.a.parameters(), lr=lr)
    # x does not change inside the loop, so pad it (and add the channel axis) once.
    xp = self._pad(x, padval).unsqueeze(0)
    for _ in range(iters):
      opt.zero_grad()
      ax = t.nn.functional.conv2d(xp, weight=self.a.weight).squeeze(0)
      loss = t.sum((ax - yplusbd) ** 2)
      loss.backward()
      opt.step()
    # NOTE: the kernel is left unconstrained. If a [0, 1] constraint is wanted,
    # apply `self.a.weight.clamp_(0, 1)` under t.no_grad() after each step;
    # the out-of-place `clamp` would have no effect.

  def xmin(self, b, pval=0, iters=40, lr=0.04):
    """Estimate x such that conv(pad(x, pval), kernel) approximates b.

    The kernel is held fixed. The search starts from x = b and takes `iters`
    AdamW steps on sum((conv(pad(x, pval), kernel) - b) ** 2).

    Returns:
      A detached tensor with the same shape as b.
    """
    target = b.unsqueeze(0)
    kernel = self.a.weight.detach()
    # Fresh leaf tensor so the optimiser owns it and no graph from b leaks in.
    x = target.detach().clone().requires_grad_(True)
    opt = t.optim.AdamW([x], lr=lr)
    for _ in trange(iters):
      opt.zero_grad()
      y = t.nn.functional.conv2d(self._pad(x, pval), kernel)
      loss = t.sum((y - target) ** 2)
      loss.backward()
      opt.step()
    return x.detach().squeeze(0)

  def getAlphaInv(self):
    """Return the kernel's centre tap as a detached (1, 1) tensor."""
    w, h = self.dims
    return self.a.weight[:, :, h, w].detach()

"""conv = t.nn.Conv2d(in_channels=1, out_channels=1, kernel_size=3, padding=(1,1), bias=False)

new_weights = t.tensor([[[[0, 0, 0], 
                          [0, 1, 0], 
                          [0, 0.1, 0.1]]]], dtype=t.float32)
with t.no_grad():
  conv.weight.copy_(new_weights)

m=PsfGuess(1,1,0.01)
for i in range(10):
  b=t.zeros((5,5))
  b[2,2]=3
  x=m.xmin(b)
  with t.no_grad():
    g=conv(x.unsqueeze(0)).squeeze(0)-b
  m.update(x,g,b)
  print(m.a.weight)"""
CANVSIZE=256
class PsfFinder:
  """Drives a PsfGuess kernel estimate on a CANVSIZE x CANVSIZE canvas.

  The canvas target ``b`` equals ``beta`` everywhere except a 2x2 square around
  the centre, where it equals 1. ``get`` solves for an input image under the
  current kernel estimate. ``update`` refits the kernel from the response
  measured for that image.

  Args:
    device: torch device for all tensors (default ``"cuda"``).
  """

  def __init__(self, device="cuda"):
    self.device = device
    self.beta = 0.01
    self.bg = t.ones((CANVSIZE, CANVSIZE), dtype=t.float32, device=device)
    self.bd = t.zeros((CANVSIZE, CANVSIZE), dtype=t.float32, device=device)
    c = CANVSIZE // 2
    # 2x2 square of ones around the canvas centre.
    self.bd[c-1:c+1, c-1:c+1] = 1
    # Set by get(); update() needs it.
    self.lastx = None
    # Populate betainv, b and m so get() works without an explicit reset().
    self.reset()

  def reset(self, beta=None):
    """Start a fresh kernel estimate, optionally changing the background level beta."""
    if beta is not None:
      self.beta = beta
    self.betainv = 1.0 - self.beta
    # beta on the background; beta + (1 - beta) = 1 on the central square.
    self.b = self.beta * self.bg + self.bd * self.betainv
    self.m = PsfGuess(64, 64).to(self.device)

  def get(self):
    """Solve for the input image under the current kernel and return it.

    Returns:
      util.srgb applied to a (2, CANVSIZE, CANVSIZE) float32 array. Channel 0
      is the solved image and channel 1 is the constant background beta * bg.
    """
    x = self.m.xmin(self.b, self.beta)
    ret = np.zeros((2, CANVSIZE, CANVSIZE), dtype=np.float32)
    ret[0] = x.cpu().numpy()
    ret[1] = self.beta * self.bg.cpu().numpy()
    self.lastx = x
    return srgb(ret)

  def update(self, res):
    """Refit the kernel from the response `res` observed for the last get() image.

    Args:
      res: value accepted by util.inverseSrgb whose result is a NumPy array
        reshapeable to (CANVSIZE, CANVSIZE).

    Raises:
      RuntimeError: if called before get().
    """
    if self.lastx is None:
      raise RuntimeError("PsfFinder.update() called before get()")
    lin = inverseSrgb(res).reshape(CANVSIZE, CANVSIZE)
    # Cast to float32 to match the model weights and canvas tensors.
    observed = t.from_numpy(lin).to(device=self.device, dtype=t.float32)
    yplusbd = observed + self.m.getAlphaInv() * self.betainv * self.bd
    self.m.update(self.lastx, yplusbd.detach(), padval=self.beta)

if __name__ == "__main__":
  # Smoke test: build a finder, start a fresh estimate, and run one x-solve
  # using xmin's default padding value.
  finder = PsfFinder()
  finder.reset()
  finder.m.xmin(finder.b)
    